{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Quantization Quickstart\n\nQuantization reduces model size and speeds up inference by reducing the number of bits required to represent weights or activations.\n\nIn NNI, both post-training quantization algorithms and quantization-aware training algorithms are supported.\nHere we use `QATQuantizer` as an example to show the usage of quantization in NNI.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Preparation\n\nIn this tutorial, we use a simple model and pre-train it on the MNIST dataset.\nIf you are already familiar with defining and training a model in PyTorch, you can skip directly to [Quantizing Model](#Quantizing-Model).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import time\nfrom typing import Callable, Union, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.optim import Optimizer, SGD\nfrom torch.utils.data import DataLoader\nfrom torch import Tensor\n\nfrom nni.common.types import SCHEDULER"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Define the model\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Mnist(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)\n        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)\n        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)\n        self.fc2 = torch.nn.Linear(500, 10)\n        self.relu1 = torch.nn.ReLU6()\n        self.relu2 = torch.nn.ReLU6()\n        self.relu3 = torch.nn.ReLU6()\n        self.max_pool1 = torch.nn.MaxPool2d(2, 2)\n        self.max_pool2 = torch.nn.MaxPool2d(2, 2)\n        self.batchnorm1 = torch.nn.BatchNorm2d(20)\n\n    def forward(self, x):\n        x = self.relu1(self.batchnorm1(self.conv1(x)))\n        x = self.max_pool1(x)\n        x = self.relu2(self.conv2(x))\n        x = self.max_pool2(x)\n        x = x.view(-1, 4 * 4 * 50)\n        x = self.relu3(self.fc1(x))\n        x = self.fc2(x)\n        return F.log_softmax(x, dim=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Create training and evaluation dataloader\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "from torchvision import transforms\n",
        "from torchvision.datasets import MNIST\n",
        "\n",
        "# Directory that holds (or will receive) the MNIST files.\n",
        "DATA_ROOT = 'data/mnist'\n",
        "\n",
        "# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.\n",
        "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
        "\n",
        "# `download=True` only fetches the files when they are not already present, so\n",
        "# each split is built once here instead of once to download and again to load.\n",
        "mnist_train = MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)\n",
        "mnist_test = MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)\n",
        "\n",
        "# Training batches of 64 samples, served in dataset order (no shuffling).\n",
        "train_dataloader = DataLoader(mnist_train, batch_size=64)\n",
        "# Larger batches of 1000 samples for evaluation.\n",
        "test_dataloader = DataLoader(mnist_test, batch_size=1000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Define training and evaluation functions\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Use the first GPU when CUDA is available, otherwise fall back to the CPU.\n",
        "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "\n",
        "def training_step(batch, model) -> Tensor:\n",
        "    \"\"\"Compute the NLL loss of ``model`` on one ``(inputs, targets)`` batch.\n",
        "\n",
        "    Both tensors are moved to ``device`` first. The model output is handed to\n",
        "    ``F.nll_loss`` unchanged, so the model is expected to emit log-probabilities.\n",
        "    \"\"\"\n",
        "    inputs = batch[0].to(device)\n",
        "    targets = batch[1].to(device)\n",
        "    return F.nll_loss(model(inputs), targets)\n",
        "\n",
        "\n",
        "def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,\n",
        "                   max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):\n",
        "    \"\"\"Train ``model`` on the global ``train_dataloader``.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    model : module to optimize; it is switched to training mode.\n",
        "    optimizer : zeroed and stepped once per batch.\n",
        "    training_step : called as ``training_step(batch, model)``; must return the loss.\n",
        "    scheduler : optional scheduler, stepped once after every completed epoch.\n",
        "    max_steps : return as soon as this many batches were processed\n",
        "        (``None`` or 0 disables the limit).\n",
        "    max_epochs : number of passes over the data. A falsy value falls back to\n",
        "        1 when ``max_steps`` is ``None`` and to 100 otherwise.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    if not max_epochs:\n",
        "        max_epochs = 1 if max_steps is None else 100\n",
        "    steps_done = 0\n",
        "\n",
        "    for epoch in range(max_epochs):\n",
        "        print(f'Epoch {epoch} start!')\n",
        "        for batch in train_dataloader:\n",
        "            optimizer.zero_grad()\n",
        "            training_step(batch, model).backward()\n",
        "            optimizer.step()\n",
        "            steps_done += 1\n",
        "            # Step budget reached: leave right away (the scheduler is not stepped).\n",
        "            if max_steps and steps_done == max_steps:\n",
        "                return\n",
        "        if scheduler is not None:\n",
        "            scheduler.step()\n",
        "\n",
        "\n",
        "def evaluating_model(model: torch.nn.Module):\n",
        "    \"\"\"Return the accuracy of ``model`` on the global ``test_dataloader``.\n",
        "\n",
        "    The model is switched to evaluation mode and gradients are disabled. The\n",
        "    result is the number of correct predictions divided by ``len(mnist_test)``.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    num_correct = 0\n",
        "    with torch.no_grad():\n",
        "        for inputs, targets in test_dataloader:\n",
        "            inputs = inputs.to(device)\n",
        "            targets = targets.to(device)\n",
        "            predictions = model(inputs).argmax(dim=1)\n",
        "            num_correct += (predictions == targets.view_as(predictions)).sum().item()\n",
        "    return num_correct / len(mnist_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Pre-train and evaluate the model on the MNIST dataset\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fresh model on the selected device, optimized with SGD + momentum.\n",
        "model = Mnist().to(device)\n",
        "optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
        "\n",
        "# Pre-train for 5 epochs (no scheduler, no step limit) and report the wall-clock time.\n",
        "start = time.time()\n",
        "training_model(model, optimizer, training_step, scheduler=None, max_steps=None, max_epochs=5)\n",
        "train_seconds = time.time() - start\n",
        "print(f'pure training 5 epochs: {train_seconds}s')\n",
        "\n",
        "# Accuracy of the not-yet-quantized model on the test split, with its timing.\n",
        "start = time.time()\n",
        "acc = evaluating_model(model)\n",
        "eval_seconds = time.time() - start\n",
        "print(f'pure evaluating: {eval_seconds}s    Acc.: {acc}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Quantizing Model\n\nInitialize a `config_list`.\nFor details about how to write a ``config_list``, please refer to :doc:`Config Specification <../compression/config_list>`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import nni\n",
        "from nni.compression.quantization import QATQuantizer\n",
        "from nni.compression.utils import TorchEvaluator\n",
        "\n",
        "\n",
        "# The SGD class is wrapped with nni.trace before it is instantiated.\n",
        "# NOTE(review): presumably so NNI can re-create the optimizer - confirm against the TorchEvaluator docs.\n",
        "optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
        "evaluator = TorchEvaluator(training_model, optimizer, training_step)  # type: ignore\n",
        "\n",
        "config_list = [{\n",
        "    # Convolution / linear layers: int8 affine settings for input, weight and output.\n",
        "    'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],\n",
        "    'target_names': ['_input_', 'weight', '_output_'],\n",
        "    'quant_dtype': 'int8',\n",
        "    'quant_scheme': 'affine',\n",
        "    'granularity': 'default',\n",
        "}, {\n",
        "    # Activation layers: same settings, applied to the output only.\n",
        "    'op_names': ['relu1', 'relu2', 'relu3'],\n",
        "    'target_names': ['_output_'],\n",
        "    'quant_dtype': 'int8',\n",
        "    'quant_scheme': 'affine',\n",
        "    'granularity': 'default',\n",
        "}]\n",
        "\n",
        "# The 4th positional argument is the number of batches in one epoch.\n",
        "# NOTE(review): presumably the step at which quantization starts - confirm against the QATQuantizer docs.\n",
        "quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))\n",
        "# Run one real training batch through the wrapped model before compressing.\n",
        "real_input = next(iter(train_dataloader))[0].to(device)\n",
        "quantizer.track_forward(real_input)\n",
        "\n",
        "start = time.time()\n",
        "_, calibration_config = quantizer.compress(None, max_epochs=5)\n",
        "# Labelled 'quantization' so this timing is not confused with the pre-training\n",
        "# timing, which is printed as 'pure training'.\n",
        "print(f'quantization training 5 epochs: {time.time() - start}s')\n",
        "\n",
        "print(calibration_config)\n",
        "start = time.time()\n",
        "acc = evaluating_model(model)\n",
        "print(f'quantization evaluating: {time.time() - start}s    Acc.: {acc}')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}